package com.add

import org.apache.spark.rdd.RDD
import org.apache.spark.util.AccumulatorV2
import org.apache.spark.{SparkConf, SparkContext}


/**
 * 自定义Int类型的累加器
 */
object MyIntAccumulator {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("Add").setMaster("local[2]")
    val sc: SparkContext = new SparkContext(conf)
    val list1 = List(30, 50, 70, 60, 10, 20)
    val rdd1: RDD[Int] = sc.parallelize(list1, 2)

    // 先注册自定义的累加器
    val acc = new MyIntAccumulator
    sc.register(acc, "first")

    val rdd2: RDD[Int] = rdd1.map(x => {
      acc.add(1)
      x
    })
    rdd2.collect
    println(acc.value)
    sc.stop()
  }
}


/**
 * 自定义int累加器
 * 泛型的意思: 第一个是调用这个累加器传什么值, 第二个泛型意思是返回的结果是什么类型的
 * 比如说对int值进行累加,那么第一个值就是int类型的,
 * 这个累加器返回int值,那么第二个泛型就是int类型的.
 */
class MyIntAccumulator extends AccumulatorV2[Int, Int] {
  private var sum = 0

  /**
   * 判断是不是"零", 对缓冲区值进行判"零"
   * 当然这个具体得看业务了,比如说集合,那么就是判空集合,
   * 如果是map累加器那么就是空map
   * 如果是字符串的累加器,那么就是判断空字符串
   * 所以具体得看业务了.
   *
   * @return
   */
  override def isZero: Boolean = sum == 0

  /** 把当前的累加复制为一个新的累加器
   *
   * @return
   */
  override def copy(): AccumulatorV2[Int, Int] = {
    val acc = new MyIntAccumulator
    acc.sum = sum
    acc
  }

  /** 重置累加器(就是把缓冲区的值重置为"零")
   *
   */
  override def reset(): Unit = sum = 0


  /** 真正的累加方法,这个是分区内的累加,多个分区各自累加
   *
   * @param v
   */
  override def add(v: Int): Unit = sum += v

  /** 分区间的合并  把other的sum合并到this的sum中
   * 把所有分区的值再累加在一起.
   *
   * @param other
   */
  override def merge(other: AccumulatorV2[Int, Int]): Unit = other match {
    case acc: MyIntAccumulator => this.sum += acc.sum
    case _ => this.sum += 0
  }

  /** 返回累加后的最终值
   *
   * @return
   */
  override def value: Int = sum
}
